{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Environment variables configured ahead of the JAX / MuJoCo imports in the\n",
    "# next cell.\n",
    "TRITON_GEMM_FLAG = \"--xla_gpu_triton_gemm_any=True\"\n",
    "\n",
    "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
    "# Append the flag only once so that re-running this cell does not keep growing\n",
    "# XLA_FLAGS with duplicates.\n",
    "if TRITON_GEMM_FLAG not in xla_flags.split():\n",
    "  xla_flags += f\" {TRITON_GEMM_FLAG}\"\n",
    "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
    "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import json\n",
    "from datetime import datetime\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jp\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import mediapy as media\n",
    "import mujoco\n",
    "import numpy as np\n",
    "import wandb\n",
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from brax.training.agents.ppo import train as ppo\n",
    "from etils import epath\n",
    "from flax.training import orbax_utils\n",
    "from IPython.display import clear_output, display\n",
    "from orbax import checkpoint as ocp\n",
    "\n",
    "from mujoco_playground import locomotion, wrapper\n",
    "from mujoco_playground.config import locomotion_params\n",
    "\n",
    "# Enable persistent compilation cache.\n",
    "jax.config.update(\"jax_compilation_cache_dir\", \"/tmp/jax_cache\")\n",
    "jax.config.update(\"jax_persistent_cache_min_entry_size_bytes\", -1)\n",
    "jax.config.update(\"jax_persistent_cache_min_compile_time_secs\", 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment to train. Alternatives:\n",
    "#   \"BerkeleyHumanoidJoystickRoughTerrain\"\n",
    "#   \"BerkeleyHumanoidJoystickFlatTerrain\"\n",
    "env_name = \"G1JoystickFlatTerrain\"\n",
    "\n",
    "# Look up the PPO hyperparameters, the domain randomizer and the default env\n",
    "# config for the selected environment.\n",
    "ppo_params = locomotion_params.brax_ppo_config(env_name)\n",
    "randomizer = locomotion.get_domain_randomizer(env_name)\n",
    "env_cfg = locomotion.get_default_config(env_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero out the energy and joint-acceleration reward terms for this run.\n",
    "for disabled_term in (\"energy\", \"dof_acc\"):\n",
    "  setattr(env_cfg.reward_config.scales, disabled_term, 0.0)\n",
    "\n",
    "# Optional overrides (uncomment to use):\n",
    "# ppo_params.num_timesteps = 50_000_000\n",
    "# ppo_params.num_evals = 5\n",
    "# ppo_params.learning_rate = 1e-3\n",
    "# env_cfg.action_scale = 1.0\n",
    "# env_cfg.command_config.a = [1.5, 0.8, jp.pi]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "pprint(ppo_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights & Biases logging is opt-in: flip USE_WANDB to True to enable it.\n",
    "USE_WANDB = False\n",
    "\n",
    "if USE_WANDB:\n",
    "  # The env config is stored as the run config; the env name is added on top.\n",
    "  wandb.init(project=\"mjxrl\", config=env_cfg)\n",
    "  wandb.config.update(dict(env_name=env_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional suffix appended to the experiment name.\n",
    "SUFFIX = None\n",
    "\n",
    "# Directory of a previous run to resume from, or None to train from scratch.\n",
    "FINETUNE_PATH = None\n",
    "\n",
    "# Generate unique experiment name.\n",
    "now = datetime.now()\n",
    "timestamp = now.strftime(\"%Y%m%d-%H%M%S\")\n",
    "exp_name = f\"{env_name}-{timestamp}\"\n",
    "if SUFFIX is not None:\n",
    "  exp_name += f\"-{SUFFIX}\"\n",
    "print(f\"{exp_name}\")\n",
    "\n",
    "# Possibly restore from the latest checkpoint.\n",
    "restore_checkpoint_path = None\n",
    "if FINETUNE_PATH is not None:\n",
    "  FINETUNE_PATH = epath.Path(FINETUNE_PATH)\n",
    "  # Checkpoints are written to sub-directories named after the training step.\n",
    "  # Ignore anything else so that int() below cannot fail on a stray directory.\n",
    "  latest_ckpts = [\n",
    "      ckpt\n",
    "      for ckpt in FINETUNE_PATH.glob(\"*\")\n",
    "      if ckpt.is_dir() and ckpt.name.isdigit()\n",
    "  ]\n",
    "  if not latest_ckpts:\n",
    "    raise FileNotFoundError(f\"No checkpoints found in: {FINETUNE_PATH}\")\n",
    "  latest_ckpts.sort(key=lambda x: int(x.name))\n",
    "  latest_ckpt = latest_ckpts[-1]\n",
    "  restore_checkpoint_path = latest_ckpt\n",
    "  print(f\"Restoring from: {restore_checkpoint_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create this run's checkpoint directory under ./checkpoints.\n",
    "ckpt_path = epath.Path(\"checkpoints\").resolve() / exp_name\n",
    "ckpt_path.mkdir(parents=True, exist_ok=True)\n",
    "print(f\"{ckpt_path}\")\n",
    "\n",
    "# Store the env config next to the checkpoints.\n",
    "config_file = ckpt_path / \"config.json\"\n",
    "with open(config_file, \"w\") as fp:\n",
    "  json.dump(env_cfg.to_dict(), fp, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_data, y_data, y_dataerr = [], [], []\n",
    "ep_length_mean, ep_length_std = [], []\n",
    "# times[0] is taken before training starts; progress() appends one timestamp\n",
    "# per call, so y_data[i] was recorded at times[i + 1].\n",
    "times = [datetime.now()]\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "\n",
    "def progress(num_steps, metrics):\n",
    "  \"\"\"Logs the latest eval metrics and redraws the live training curves.\n",
    "\n",
    "  Passed to `ppo.train` as `progress_fn`.\n",
    "\n",
    "  Args:\n",
    "    num_steps: Step count reported by the trainer; used as the x coordinate.\n",
    "    metrics: Mapping of metric names to values. The `eval/episode_reward`,\n",
    "      `eval/episode_reward_std`, `eval/avg_episode_length` and\n",
    "      `eval/avg_episode_length_std` entries are recorded and plotted.\n",
    "  \"\"\"\n",
    "  # Log to wandb.\n",
    "  if USE_WANDB:\n",
    "    wandb.log(metrics, step=num_steps)\n",
    "\n",
    "  # Record the new data point.\n",
    "  clear_output(wait=True)\n",
    "  times.append(datetime.now())\n",
    "  x_data.append(num_steps)\n",
    "  y_data.append(metrics[\"eval/episode_reward\"])\n",
    "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
    "  ep_length_mean.append(metrics[\"eval/avg_episode_length\"])\n",
    "  ep_length_std.append(metrics[\"eval/avg_episode_length_std\"])\n",
    "\n",
    "  # Plot: reward on the left panel, episode length on the right panel.\n",
    "  panels = (\n",
    "      (axes[0], \"reward per episode\", y_data, y_dataerr),\n",
    "      (axes[1], \"episode length\", ep_length_mean, ep_length_std),\n",
    "  )\n",
    "  for ax, ylabel, means, stds in panels:\n",
    "    ax.set_xlim([0, ppo_params.num_timesteps * 1.25])\n",
    "    ax.set_xlabel(\"# environment steps\")\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_title(f\"y={means[-1]:.3f}\")\n",
    "    ax.errorbar(x_data, means, yerr=stds, color=\"blue\")\n",
    "\n",
    "  display(plt.gcf())\n",
    "\n",
    "\n",
    "def policy_params_fn(current_step, make_policy, params):\n",
    "  \"\"\"Saves `params` as an Orbax checkpoint in `ckpt_path / str(current_step)`.\n",
    "\n",
    "  Passed to `ppo.train` as `policy_params_fn`.\n",
    "\n",
    "  Args:\n",
    "    current_step: Step count used as the checkpoint directory name.\n",
    "    make_policy: Unused.\n",
    "    params: PyTree of parameters to save.\n",
    "  \"\"\"\n",
    "  del make_policy  # Unused.\n",
    "  step_dir = ckpt_path / f\"{current_step}\"\n",
    "  # force=True overwrites an existing checkpoint for the same step.\n",
    "  ocp.PyTreeCheckpointer().save(\n",
    "      step_dir,\n",
    "      params,\n",
    "      force=True,\n",
    "      save_args=orbax_utils.save_args_from_target(params),\n",
    "  )\n",
    "\n",
    "\n",
    "# network_factory holds keyword arguments for make_ppo_networks rather than a\n",
    "# callable, so split it off and wrap it before handing everything to ppo.train.\n",
    "training_params = dict(ppo_params)\n",
    "network_factory_kwargs = training_params.pop(\"network_factory\")\n",
    "\n",
    "train_fn = functools.partial(\n",
    "    ppo.train,\n",
    "    **training_params,\n",
    "    network_factory=functools.partial(\n",
    "        ppo_networks.make_ppo_networks, **network_factory_kwargs\n",
    "    ),\n",
    "    restore_checkpoint_path=restore_checkpoint_path,\n",
    "    progress_fn=progress,\n",
    "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
    "    policy_params_fn=policy_params_fn,\n",
    "    randomization_fn=randomizer,\n",
    ")\n",
    "\n",
    "env = locomotion.load(env_name, config=env_cfg)\n",
    "eval_env = locomotion.load(env_name, config=env_cfg)\n",
    "make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)\n",
    "if len(times) > 1:\n",
    "  print(f\"time to jit: {times[1] - times[0]}\")\n",
    "  print(f\"time to train: {times[-1] - times[1]}\")\n",
    "\n",
    "# Make a final plot of reward vs WALLCLOCK time.\n",
    "plt.figure()\n",
    "plt.xlabel(\"wallclock time (s)\")\n",
    "plt.ylabel(\"reward per episode\")\n",
    "plt.title(f\"y={y_data[-1]:.3f}\")\n",
    "plt.errorbar(\n",
    "    # Skip times[0] (start of the cell): the i-th evaluation finished at\n",
    "    # times[i + 1], so times[1:] lines up one-to-one with y_data.\n",
    "    [(t - times[0]).total_seconds() for t in times[1:]],\n",
    "    y_data,\n",
    "    yerr=y_dataerr,\n",
    "    color=\"blue\",\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "inference_fn = make_inference_fn(params, deterministic=True)\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save normalizer, policy and value params to the checkpoint dir.\n",
    "normalizer_params, policy_params, value_params = params\n",
    "data = dict(\n",
    "    normalizer_params=normalizer_params,\n",
    "    policy_params=policy_params,\n",
    "    value_params=value_params,\n",
    ")\n",
    "with open(ckpt_path / \"params.pkl\", \"wb\") as f:\n",
    "  pickle.dump(data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mujoco_playground._src.gait import draw_joystick_command\n",
    "\n",
    "# Rebuild the eval env from the default config.\n",
    "# NOTE: this rebinds env_cfg, so the reward-scale overrides applied before\n",
    "# training are not part of the config used by the rollout and plotting cells\n",
    "# below. No perturbation settings are changed here.\n",
    "env_cfg = locomotion.get_default_config(env_name)\n",
    "eval_env = locomotion.load(env_name, config=env_cfg)\n",
    "\n",
    "# Jit the env entry points once so that the rollout loop reuses them.\n",
    "jit_reset = jax.jit(eval_env.reset)\n",
    "jit_step = jax.jit(eval_env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(1)\n",
    "rollout = []\n",
    "modify_scene_fns = []\n",
    "\n",
    "# Per-step logs filled in by the rollout loop and read by the plotting cells.\n",
    "swing_peak, rewards, linvel, angvel, track = [], [], [], [], []\n",
    "hand_positions, pelvis_position = [], []\n",
    "foot_vel, rews, qvels = [], [], []\n",
    "power1, power2, qpos_cost = [], [], []\n",
    "rews_ep = []\n",
    "\n",
    "# Joystick command written into state.info after every step.\n",
    "x_cmd = 0.0\n",
    "y_cmd = 0.0\n",
    "yaw_cmd = 0\n",
    "command = jp.array([x_cmd, y_cmd, yaw_cmd])\n",
    "\n",
    "# Phase overrides written into state.info right after reset.\n",
    "# NOTE(review): the 1.5 factor looks like a gait frequency in Hz - confirm.\n",
    "phase_dt = 2 * jp.pi * eval_env.dt * 1.5\n",
    "phase = jp.array([0, jp.pi])\n",
    "\n",
    "for j in range(1):\n",
    "  print(f\"episode {j}\")\n",
    "  state = jit_reset(rng)\n",
    "  # state.info[\"command\"] = command\n",
    "  state.info[\"phase_dt\"] = phase_dt\n",
    "  state.info[\"phase\"] = phase\n",
    "  ep_rews = []\n",
    "  for i in range(env_cfg.episode_length):\n",
    "    # if i % 200 == 0:\n",
    "    #   x_cmd = -x_cmd\n",
    "    #   y_cmd = -y_cmd\n",
    "    #   yaw_cmd = -yaw_cmd\n",
    "    #   command = jp.array([x_cmd, y_cmd, yaw_cmd])\n",
    "    #   print(command)\n",
    "    act_rng, rng = jax.random.split(rng)\n",
    "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "    state = jit_step(state, ctrl)\n",
    "    ep_rews.append(state.reward)\n",
    "    if state.done:\n",
    "      print(\"something bad happened\")\n",
    "      break\n",
    "    state.info[\"command\"] = command\n",
    "    rollout.append(state)\n",
    "\n",
    "    # Reward terms: rews keeps the full metric keys, rewards strips the prefix.\n",
    "    reward_terms = {\n",
    "        k: v for k, v in state.metrics.items() if k.startswith(\"reward/\")\n",
    "    }\n",
    "    rews.append(reward_terms)\n",
    "    rewards.append({k[len(\"reward/\"):]: v for k, v in reward_terms.items()})\n",
    "    swing_peak.append(state.info[\"swing_peak\"])\n",
    "\n",
    "    # Base velocities and the linear-velocity tracking reward.\n",
    "    local_linvel = eval_env.get_local_linvel(state.data)\n",
    "    linvel.append(local_linvel)\n",
    "    angvel.append(eval_env.get_gyro(state.data))\n",
    "    track.append(\n",
    "        eval_env._reward_tracking_lin_vel(state.info[\"command\"], local_linvel)\n",
    "    )\n",
    "\n",
    "    # hand_positions.append(state.data.site_xpos[eval_env._hands_site_id])\n",
    "    # pelvis_position.append(state.data.site_xpos[eval_env._pelvis_imu_site_id])\n",
    "\n",
    "    # Horizontal foot speed; the sqrt of the norm is kept as in the original.\n",
    "    feet_vel = state.data.sensordata[eval_env._foot_linvel_sensor_adr]\n",
    "    vel_xy = feet_vel[..., :2]\n",
    "    vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))\n",
    "    foot_vel.append(vel_norm)\n",
    "\n",
    "    # Power: signed sum (power1) and sum of magnitudes (power2).\n",
    "    frc = state.data.actuator_force\n",
    "    qvel = state.data.qvel[6:]\n",
    "    power1.append(jp.sum(frc * qvel))\n",
    "    power2.append(jp.sum(jp.abs(frc * qvel)))\n",
    "\n",
    "    qvels.append(qvel)\n",
    "    qpos_cost.append(jp.sum(jp.square(state.data.qpos[7:] - eval_env._default_pose)))\n",
    "\n",
    "    # Position and heading are both read through the env's own torso body id,\n",
    "    # so the command arrow follows the same body whatever the model names it.\n",
    "    torso_id = eval_env._torso_body_id\n",
    "    xyz = np.array(state.data.xpos[torso_id])\n",
    "    x_axis = state.data.xmat[torso_id, 0]\n",
    "    yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
    "    modify_scene_fns.append(\n",
    "        functools.partial(\n",
    "            draw_joystick_command,\n",
    "            cmd=state.info[\"command\"],\n",
    "            xyz=xyz,\n",
    "            theta=yaw,\n",
    "            scl=np.linalg.norm(state.info[\"command\"]),\n",
    "        )\n",
    "    )\n",
    "  rews_ep.append(ep_rews)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Signed power from the rollout (power1), smoothed with a moving average.\n",
    "power_window = 10\n",
    "power = jp.convolve(\n",
    "    jp.array(power1), jp.ones(power_window) / power_window, mode='valid'\n",
    ")\n",
    "mean_power = jp.mean(power)\n",
    "\n",
    "plt.plot(power)\n",
    "# Red line: mean of the smoothed power. Dashed black line: zero.\n",
    "plt.axhline(y=mean_power, color='r', linestyle='-', label='Mean')\n",
    "plt.axhline(y=0, color='k', linestyle='--')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# Print min, max and mean.\n",
    "print(f\"Min power: {jp.min(power)}, Max power: {jp.max(power)}\", \"Mean power: \", mean_power)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every render_every-th rollout state; fps is scaled to match.\n",
    "render_every = 1\n",
    "fps = 1.0 / eval_env.dt / render_every\n",
    "print(f\"fps: {fps}\")\n",
    "traj = rollout[::render_every]\n",
    "mod_fns = modify_scene_fns[::render_every]\n",
    "\n",
    "# Visualisation options: geom group 2 on, group 3 off, contact points drawn.\n",
    "scene_option = mujoco.MjvOption()\n",
    "scene_option.geomgroup[2] = True\n",
    "scene_option.geomgroup[3] = False\n",
    "for vis_flag, enabled in (\n",
    "    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),\n",
    "    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),\n",
    "    (mujoco.mjtVisFlag.mjVIS_PERTFORCE, False),\n",
    "):\n",
    "  scene_option.flags[vis_flag] = enabled\n",
    "\n",
    "frames = eval_env.render(\n",
    "    traj,\n",
    "    camera=\"track\",\n",
    "    scene_option=scene_option,\n",
    "    width=1280,\n",
    "    height=480,\n",
    "    modify_scene_fns=mod_fns,\n",
    ")\n",
    "media.show_video(frames, fps=fps, loop=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _smoothed_reward_term(term, window=10):\n",
    "  \"\"\"Returns the per-step values of reward `term`, moving-average smoothed.\"\"\"\n",
    "  values = jp.array([r[term] for r in rewards])\n",
    "  return jp.convolve(values, jp.ones(window) / window, mode=\"same\")\n",
    "\n",
    "\n",
    "angvel_reward = _smoothed_reward_term(\"tracking_ang_vel\")\n",
    "linvel_reward = _smoothed_reward_term(\"tracking_lin_vel\")\n",
    "\n",
    "plt.plot(angvel_reward, label=\"angvel\")\n",
    "plt.plot(linvel_reward, label=\"linvel\")\n",
    "# Dashed lines mark the configured scale of each tracking reward.\n",
    "tracking_scales = env_cfg.reward_config.scales\n",
    "plt.axhline(tracking_scales.tracking_ang_vel, color=\"k\", linestyle=\"--\")\n",
    "plt.axhline(tracking_scales.tracking_lin_vel, color=\"k\", linestyle=\"--\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the swing peak of each foot in its own panel (1x2 grid).\n",
    "swing_peak = jp.array(swing_peak)\n",
    "names = [\"r_foot\", \"l_foot\",]\n",
    "colors = [\"r\", \"g\"]\n",
    "max_foot_height = env_cfg.reward_config.max_foot_height\n",
    "\n",
    "fig, axs = plt.subplots(1, 2)\n",
    "for i, (ax, name, color) in enumerate(zip(axs.flat, names, colors)):\n",
    "  ax.plot(swing_peak[:, i], color=color)\n",
    "  ax.set_ylim([0, max_foot_height * 1.25])\n",
    "  # Dashed line marks the configured max foot height.\n",
    "  ax.axhline(max_foot_height, color=\"k\", linestyle=\"--\")\n",
    "  ax.set_title(name)\n",
    "  ax.set_xlabel(\"time\")\n",
    "  ax.set_ylabel(\"height\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smooth the x / y linear velocity and the yaw rate with a 10-step moving\n",
    "# average.\n",
    "linvel_arr = jp.array(linvel)\n",
    "angvel_arr = jp.array(angvel)\n",
    "kernel = jp.ones(10) / 10\n",
    "linvel_x, linvel_y, angvel_yaw = (\n",
    "    jp.convolve(signal, kernel, mode=\"same\")\n",
    "    for signal in (linvel_arr[:, 0], linvel_arr[:, 1], angvel_arr[:, 2])\n",
    ")\n",
    "\n",
    "# One panel per command dimension: measured signal, command limits from the\n",
    "# config (black dashed) and the command of the final state (red dashed).\n",
    "fig, axes = plt.subplots(3, 1, figsize=(10, 10))\n",
    "labels = [\"dx\", \"dy\", \"dyaw\"]\n",
    "for i, (ax, signal) in enumerate(zip(axes, (linvel_x, linvel_y, angvel_yaw))):\n",
    "  ax.plot(signal)\n",
    "  limit = env_cfg.command_config.a[i]\n",
    "  ax.axhline(-limit, color=\"k\", linestyle=\"--\")\n",
    "  ax.axhline(limit, color=\"k\", linestyle=\"--\")\n",
    "  ax.axhline(state.info[\"command\"][i], color=\"r\", linestyle=\"--\")\n",
    "  ax.set_ylabel(labels[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
